//
//  
//
//  Created by Milind Chabbi on 4/27/14.
//
//

#include <stdio.h>
#include <stdint.h>
#include <omp.h>

// Status a QNode is constructed with; Acquire spins for as long as its node still reads WAIT.
// NOTE(review): this is an unsigned literal stored into the int field QNode::status;
// the == comparisons rely on the usual int -> unsigned conversion (assumes 32-bit int).
#define WAIT (0xffffffff)
// Status written by Release into a successor's node to tell that successor it must
// acquire the parent-level lock itself (see the check in Acquire).
#define ACQUIRE_PARENT (0xfffffffe)
// Atomic compare-and-swap: if *a == b, store c into *a. Evaluates to the value *a held beforehand.
#define CAS(a, b, c) __sync_val_compare_and_swap(a, b, c)
// Atomic exchange: store b into *a and evaluate to the previous contents of *a.
#define SWAP(a,b) __sync_lock_test_and_set(a, b)

// Forward declaration so HMCS can hold a QNode pointer before QNode is defined.
struct QNode;

// One lock of the hierarchy. LockInit creates these and links each one to its parent.
typedef struct HMCS{
    // Release compares QNode::status against this value; on a match (and when a parent
    // exists) it releases the parent lock instead of passing an incremented count on.
    int threshold;
    // Lock one level up. LockInit leaves this NULL for locks at the last level.
    struct HMCS * parent;
    // Acquire atomically swaps the caller's QNode into this field and Release CASes it
    // back to NULL; AcquireParent also stores here the QNode it used on the parent lock.
    // NOTE(review): those two uses share this single field -- confirm that is intended.
    struct QNode *  node;
}HMCS;

// Queue node a thread enqueues on one lock level while acquiring it.
typedef struct QNode{
    // Successor in the queue. It is written by the successor thread (Acquire does
    // pred->next = I) and busy-polled by the owner in Release, so the pointer itself
    // is volatile: every poll must re-read memory instead of a cached register copy.
    struct QNode * volatile next;
    // WAIT while the owner is still queued, ACQUIRE_PARENT when the owner is told to
    // take the parent lock, otherwise a small positive count handed on by Release.
    volatile int status;
    // A fresh node has no successor and is waiting. Initializers are listed in
    // declaration order (next, then status), which is the order they actually run in.
    QNode() : next(NULL), status(WAIT) {
    }
}QNode;

// Returns the threshold LockInit stores in each lock created for the given level.
// Every level currently gets the same fixed value; the level argument is ignored.
int GetThresholdAtLevel(int leve){
    const int uniformThreshold = 64;
    (void)leve;
    return uniformThreshold;
}

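/*

 Example machine:
   8 cores per CPU
   4 CPUs per node
   2 nodes
 ==>
 levels = 3
 participantsAtLevel = {8, 32, 64};
 (entry i is the number of consecutive thread ids that share one lock at level i:
  8 threads per CPU, 32 per node, 64 for the whole machine)

*/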

 // Flat table of pointers to every lock in the hierarchy, allocated and filled in by
 // LockInit. Each level's locks occupy one contiguous run, starting with level 0.
 HMCS ** lockLocations;

 // Number of locks a level needs when each lock is shared by `participants`
 // consecutive thread ids. Rounds up so a partially filled last group gets a lock too.
 static inline int LocksAtLevel(int maxThreads, int participants){
     return (maxThreads + participants - 1) / participants;
 }

 // Collectively builds the lock hierarchy and returns the calling thread's level-0 lock.
 //
 // Must be called by every thread of the current OpenMP team: it contains an
 // `omp single` region and two barriers.
 //
 //   tid                 - caller's thread id within the team
 //   maxThreads          - number of threads in the team
 //   levels              - number of levels in the hierarchy
 //   participantsAtLevel - entry i is the number of consecutive thread ids that share
 //                         one lock at level i (assumed positive and non-decreasing)
 //
 // The thread whose tid is a multiple of participantsAtLevel[i] is the master of its
 // level-i group and creates that group's lock. Locks are stored in lockLocations,
 // level by level, with LocksAtLevel(maxThreads, participantsAtLevel[i]) slots per level.
 HMCS * LockInit(int tid, int maxThreads, int levels, int * participantsAtLevel){
     // Total slots = sum over all levels of the (rounded-up) group count at that level.
     int totalLocksNeeded = 0;
     for (int i=0; i < levels; i++) {
         totalLocksNeeded += LocksAtLevel(maxThreads, participantsAtLevel[i]);
     }
#pragma omp single 
     {
         lockLocations = new HMCS*[totalLocksNeeded];
     }
     // The implicit barrier that ends the single region makes the table visible to all.

     // Pass 1: each group master creates its lock. levelBase is the first slot of the
     // current level and is advanced on every iteration, master or not, so the offsets
     // do not depend on a thread having been master at all lower levels.
     int levelBase = 0;
     for(int curLevel = 0 ; curLevel < levels; curLevel++){
         if (tid%participantsAtLevel[curLevel] == 0) {
             HMCS * curLock = new HMCS();
             curLock->threshold = GetThresholdAtLevel(curLevel);
             curLock->parent = NULL;
             curLock->node = NULL;
             lockLocations[levelBase + tid/participantsAtLevel[curLevel]] = curLock;
         }
         levelBase += LocksAtLevel(maxThreads, participantsAtLevel[curLevel]);
     }
#pragma omp barrier
     // Pass 2: every lock now exists, so each master links its lock to the lock of the
     // enclosing group one level up. Locks at the last level keep parent == NULL.
     levelBase = 0;
     for(int curLevel = 0 ; curLevel < levels - 1; curLevel++){
         int parentBase = levelBase + LocksAtLevel(maxThreads, participantsAtLevel[curLevel]);
         if (tid%participantsAtLevel[curLevel] == 0) {
             int lockLocation = levelBase + tid/participantsAtLevel[curLevel];
             int parentLockLocation = parentBase + tid/participantsAtLevel[curLevel+1];
             lockLocations[lockLocation]->parent = lockLocations[parentLockLocation];
         }
         levelBase = parentBase;
     }
#pragma omp barrier
     // Level 0 starts at slot 0, so the caller's own lock is indexed by its group number.
     return lockLocations[tid/participantsAtLevel[0]];
     
 }

// Forward declaration: AcquireParent and Acquire call each other.
void Acquire(HMCS * L, QNode *I);
// Acquires L->parent using a freshly heap-allocated QNode and records that node in
// L->node. The parameter I is not used by this function.
void AcquireParent(HMCS * L, QNode *I) {
    // Node used to queue on the parent lock; Release later passes L->node to the
    // parent-level Release and then deletes it.
    QNode * I2 = new QNode();
    Acquire(L->parent, I2);
    // NOTE(review): Acquire(L, ...) uses L->node as its SWAP target, so this plain
    // store replaces whatever pointer was there -- confirm that is intended.
    L->node = I2;
    // beginning of cohort
    // NOTE(review): Acquire(L->parent, I2) has already assigned I2->status; this
    // overwrites that value with 1 -- confirm.
    I2->status = 1;
}


// Acquires lock L using queue node I, taking L->parent as well when this caller
// finds nobody ahead of it or is told to by its predecessor.
// On return I->status holds a value other than WAIT / ACQUIRE_PARENT: either 1 or
// the count written into it by the predecessor's Release.
void Acquire(HMCS * L, QNode *I) {
    // Atomically publish I in L->node and fetch the node that was there before.
    QNode * pred = SWAP(&(L->node), I);
    if(!pred) {
        // I am the first one at this level
        // Acquire at next level if not at the top level
        if(L->parent) {
            AcquireParent(L, I);
        }
        // beginning of cohort
        I->status = 1;
    } else {
        // Link behind the predecessor so its Release can find this node.
        pred->next = I;
        
        // Wait for tap from predecessor
        // (status is volatile, so each iteration re-reads it; the comparison of the
        // int status with the unsigned WAIT literal relies on integer conversion)
        while(I->status == WAIT) ;
        
        // The predecessor released the parent lock and asked this thread to
        // re-acquire it; any other value is a count and needs no further work.
        if(I->status == ACQUIRE_PARENT) {
            AcquireParent(L, I);
            // beginning of cohort
            I->status = 1;
        }
    }
}

// Prototype repeated immediately before the definition; Release calls itself on L->parent.
void Release(HMCS * L, QNode *I);

// Releases lock L, which the caller holds through queue node I.
// If I->status equals L->threshold and L has a parent, the parent lock is released
// and the successor at this level (if any) is told to acquire the parent itself.
// Otherwise the successor (if any) receives I->status + 1.
// In both cases, when no successor is linked yet, L->node is CASed from I to NULL;
// if the CAS fails, a successor is in the middle of enqueuing and is waited for.
void Release(HMCS * L, QNode *I) {
    if(I->status == L->threshold && L->parent) {
        // reached threshold and have next level
        // release to next level
        // (the node passed to the parent is whatever pointer L->node holds right now)
        Release(L->parent, L->node);
        // we can free L->node now
        // NOTE(review): L->node is also the field Acquire swaps callers' nodes into,
        // and main passes the address of a stack QNode -- confirm L->node is
        // guaranteed to be the heap node from AcquireParent at this point.
        // NOTE(review): L->node is not reset after the delete and is read by the
        // CAS below.
        delete L->node;

        // Tap successor at this level and ask to spin acquire next level lock
        if(I->next) {
            I->next->status = ACQUIRE_PARENT;
        } else {
            // No successor linked: try to mark the level free.
            QNode * casVal = CAS(& (L->node), I, NULL);
            if (casVal != I){
                // Someone swapped themselves in after I; wait for them to link, then tap.
                while(I->next == NULL);
                I->next->status = ACQUIRE_PARENT;
            }
        }
    } else {
        // either top level or not reached threshold
        if(I->next) {
            // Pass the lock within this level, carrying the incremented count.
            I->next->status = I->status + 1;
        } else {
            // No successor linked: try to mark the level free.
            // NOTE(review): when L->parent is non-NULL and this CAS succeeds, this
            // path returns without calling Release on L->parent -- confirm the
            // parent lock is released elsewhere.
            QNode * casVal = CAS(&(L->node), I, NULL);
            if (casVal != I){
                // Someone swapped themselves in after I; wait for them to link, then pass on.
                while(I->next == NULL);
                I->next->status = I->status + 1;
            }
        }
    }
}



// Test driver: 16 OpenMP threads build a 4-level hierarchy (groups of 2, 4, 8 and 16
// threads) and each thread acquires and releases its level-0 lock 0xffff times,
// printing its thread id while holding the lock.
// Declared with an explicit int return type; a bare `main()` is ill-formed in C++
// and in C since C99.
int main(){
    int levels = 4;
    int participantsAtLevel[] = {2, 4, 8, 16};
    omp_set_num_threads(16);

#pragma omp parallel 
    {
        int tid = omp_get_thread_num();
        int size = omp_get_num_threads();
        // Collective call: every thread of the team takes part in building the locks.
        HMCS * hmcs = LockInit(tid, size, levels, participantsAtLevel);
        
        for(int i = 0; i < 0xffff; i++) {
            // Fresh queue node per acquisition; its constructor resets next/status.
            QNode me;
            Acquire(hmcs, &me);
            printf("Acquired %d!\n", tid);
            Release(hmcs, &me);
            
        }
            

    }
    return 0;
}



